import math
import copy
import gym
import random
import numpy as np
import statistics
import pickle

# Import your updated custom/stochastic envs
import Continuous_CartPole
import Continuous_Pendulum
import continuous_mountain_car
import continuous_acrobot
import improved_hopper
import improved_ant
import improved_walker2d

from SnapshotENV import SnapshotEnv
# Suppose you have progressive widening logic in "pw_module.py"
# from pw_module import ProgressiveWideningMCTS   # or something similar

# 1) Environment IDs, evaluated in this order by the experiment loop.
env_names = [
    "Continuous-CartPole-v0",
    "StochasticPendulum-v0",
    "StochasticMountainCarContinuous-v0",
    "StochasticContinuousAcrobot-v0",
    "ImprovedHopper-v0",
    "ImprovedWalker2d-v0",
    "ImprovedAnt-v0",
]

# 2) Per-environment noise settings, forwarded as keyword arguments to gym.make().
#    Each tuple below is (action_noise_scale, dynamics_noise_scale, obs_noise_scale).
#    Earlier runs used dynamics noise 0.01 for the four classic-control tasks and
#    action noise 0.03 for MountainCar. Other constructor kwargs (e.g. "g": 9.8 for
#    a different Pendulum gravity) could be added to an entry the same way.
ENV_NOISE_CONFIG = {
    env_id: dict(zip(("action_noise_scale", "dynamics_noise_scale", "obs_noise_scale"), scales))
    for env_id, scales in (
        ("Continuous-CartPole-v0", (0.05, 0.5, 0.0)),
        ("StochasticPendulum-v0", (0.02, 0.1, 0.01)),
        ("StochasticMountainCarContinuous-v0", (0.05, 0.5, 0.0)),
        ("StochasticContinuousAcrobot-v0", (0.05, 0.7, 0.01)),
        ("ImprovedHopper-v0", (0.03, 0.02, 0.01)),
        ("ImprovedWalker2d-v0", (0.03, 0.02, 0.01)),
        ("ImprovedAnt-v0", (0.03, 0.02, 0.01)),
    )
}

# 3) Global experiment settings.
num_seeds = 20          # independent seeds per (environment, planning budget)
TEST_ITERATIONS = 150   # maximum number of real steps per test episode
discount = 0.99         # reward discount factor
MAX_MCTS_DEPTH = 100    # NOTE(review): appears unused here; max_depth is chosen per environment

# Planning budgets: 16 values growing geometrically by a factor of 1000 ** (1/15),
# from 3 up to roughly 3000. Only the first six budgets are run.
base = 1000 ** (1.0 / 15.0)
samples = [int(3 * base ** i) for i in range(16)]
samples_to_use = samples[:6]

# -------------------------------------------------------------------------
# 4) Node class with minimal progressive widening approach
#    Enhanced for higher-dimensional environments
# -------------------------------------------------------------------------
class NodePW:
    """MCTS search node with progressive widening over a continuous box action space.

    Actions are vectors in ``[min_action, max_action] ** dim`` sampled uniformly
    at random when the node widens. Each child stores the environment snapshot,
    observation, reward and terminal flag produced by applying its ``action``
    to the parent's snapshot through ``env.get_result``.

    Value convention: ``value_sum`` accumulates the discounted returns obtained
    *from this node's state onward*. It does not include the node's own
    ``immediate_reward``. The value of taking a child's action, as seen from
    the parent, is therefore ``child.immediate_reward + gamma * child.get_mean_value()``.

    The discount factor is the module-level ``discount`` unless a ``gamma`` is
    passed explicitly to :meth:`selection`, :meth:`rollout`, :meth:`select_child`
    or :meth:`get_action_value`.
    """

    # UCB1 exploration constant (approximately sqrt(2)).
    EXPLORATION_C = 1.414

    def __init__(self, snapshot, obs, is_done, parent, depth, min_action, max_action, dim):
        self.parent = parent
        self.snapshot = snapshot  # env snapshot for this node, or None until simulated
        self.obs = obs
        self.is_done = is_done
        self.depth = depth

        self.children = []  # child NodePW objects, one per sampled action
        self.actions = []   # sampled action tuples, parallel to self.children
        self.action = None  # action leading here from the parent (None for a root)
        self.immediate_reward = 0
        self.min_action = min_action
        self.max_action = max_action
        self.dim = dim

        self.visit_count = 0
        self.value_sum = 0.0

    @staticmethod
    def _gamma(gamma):
        """Return *gamma* if given, otherwise the module-level ``discount``."""
        return discount if gamma is None else gamma

    def get_mean_value(self):
        """Return the mean return observed from this node's state (0.0 if unvisited)."""
        if self.visit_count == 0:
            return 0.0
        return self.value_sum / self.visit_count

    def get_action_value(self, gamma=None):
        """Return the estimated value of the action leading to this node: r + gamma * V."""
        return self.immediate_reward + self._gamma(gamma) * self.get_mean_value()

    def _random_action(self):
        """Sample one action uniformly from the box [min_action, max_action] ** dim."""
        return tuple(random.uniform(self.min_action, self.max_action) for _ in range(self.dim))

    def select_child(self, gamma=None):
        """Return the child with the highest UCB1 score, or None if there are no children.

        The exploitation term is the child's action value (immediate reward plus
        the discounted mean return after it). Unvisited children are scored as
        if they had one visit. Ties keep the earliest child.
        """
        gamma = self._gamma(gamma)
        log_n = math.log(max(1, self.visit_count))
        best_child = None
        best_score = -math.inf
        for child in self.children:
            n = max(child.visit_count, 1)
            ucb = child.get_action_value(gamma) + self.EXPLORATION_C * math.sqrt(log_n / n)
            if ucb > best_score:
                best_score = ucb
                best_child = child
        return best_child

    def expand(self):
        """Add one child with a uniformly random action if the widening budget allows.

        The budget is ``max(5, int(0.5 * sqrt(visits)))`` children for action
        dimension <= 3, and ``max(3, int(0.3 * sqrt(visits)))`` otherwise.

        Returns:
            The newly created (not yet simulated) child, or None when the node
            already has as many children as the budget allows.
        """
        if self.dim <= 3:
            budget = max(5, int(0.5 * math.sqrt(self.visit_count)))
        else:
            budget = max(3, int(0.3 * math.sqrt(self.visit_count)))
        if len(self.children) >= budget:
            return None

        action = self._random_action()
        child = NodePW(
            snapshot=None,
            obs=None,
            is_done=False,
            parent=self,
            depth=self.depth + 1,
            min_action=self.min_action,
            max_action=self.max_action,
            dim=self.dim,
        )
        child.action = action
        self.children.append(child)
        self.actions.append(action)
        return child

    def rollout(self, env, max_depth, gamma=None):
        """Estimate this node's value with a uniformly random policy.

        Loads this node's snapshot (if any) into *env*. It then takes random
        actions for up to ``max_depth - depth`` steps, or until the episode
        ends, and returns the discounted reward sum. Terminal nodes and nodes
        at or beyond the depth limit are worth 0.0 and do not touch *env*.
        """
        if self.is_done or self.depth >= max_depth:
            return 0.0
        gamma = self._gamma(gamma)

        if self.snapshot is not None:
            env.load_snapshot(self.snapshot)

        total = 0.0
        weight = 1.0
        for _ in range(max_depth - self.depth):
            _, reward, done, _ = env.step(self._random_action())
            total += weight * reward
            weight *= gamma
            if done:
                break
        return total

    def selection(self, env, max_depth, gamma=None):
        """Run one MCTS iteration from this node and return the sampled return.

        If the widening budget allows, a new random action is simulated with
        ``env.get_result`` and evaluated. For ``dim > 6`` this uses a random
        rollout; otherwise it recurses into the new child. Otherwise the best
        existing child by UCB1 is descended into.

        The returned value is the discounted return from this node's state. It
        is also added to this node's statistics.
        """
        gamma = self._gamma(gamma)

        if self.is_done or self.depth >= max_depth:
            # Leaf worth 0. Still count the visit so the parent's UCB bonus for it shrinks.
            self.visit_count += 1
            return 0.0

        child = self.expand()
        if child is not None:
            if self.snapshot is None:
                # Only expected for a root created without a snapshot: plan from
                # whatever state *env* currently holds.
                self.snapshot = env.get_snapshot()

            result = env.get_result(self.snapshot, child.action)
            child.snapshot = result.snapshot
            child.obs = result.next_state
            child.is_done = result.is_done
            child.immediate_reward = result.reward

            if self.dim > 6:
                # High-dimensional actions: evaluate the new leaf with a random
                # rollout, and record that sample in the child's own statistics.
                child_value = child.rollout(env, max_depth, gamma)
                child.visit_count += 1
                child.value_sum += child_value
            else:
                child_value = child.selection(env, max_depth, gamma)
        else:
            child = self.select_child(gamma)
            if child is None:
                return 0.0
            child_value = child.selection(env, max_depth, gamma)

        ret = child.immediate_reward + gamma * child_value
        self.visit_count += 1
        self.value_sum += ret
        return ret

    def delete_subtree(self):
        """Detach every node in this subtree so it can be reclaimed promptly.

        Clears ``children``/``actions`` and resets ``parent`` on this node and
        all of its descendants. The traversal is iterative, to avoid recursion
        limits on deep trees.
        """
        stack = [self]
        while stack:
            node = stack.pop()
            stack.extend(node.children)
            node.children.clear()
            node.actions.clear()
            node.parent = None

# -------------------------------------------------------------------------
# 5) Main experiment script
# -------------------------------------------------------------------------
if __name__ == "__main__":
    results_filename = "pw_results.txt"
    f_out = open(results_filename, "a")

    for envname in env_names:
        # Create environment with noise settings
        stoch_kwargs = ENV_NOISE_CONFIG.get(envname, {})
        base_env = gym.make(envname, **stoch_kwargs).env

        # Figure out dimension and action ranges
        if envname == "Continuous-CartPole-v0":
            min_action = base_env.min_action
            max_action = base_env.max_action
            dim = 1
            max_depth = 50
        elif envname == "StochasticPendulum-v0":
            min_action = -2.0
            max_action = 2.0
            dim = 1
            max_depth = 50
        elif envname == "StochasticLunarLanderContinuous-v0":
            min_action = -1.0
            max_action = 1.0
            dim = 2
            max_depth = 100
        elif envname == "StochasticMountainCarContinuous-v0":
            min_action = -1.0
            max_action = 1.0
            dim = 1
            max_depth = 50
        elif envname == "StochasticContinuousAcrobot-v0":
            min_action = -1.0
            max_action = 1.0
            dim = 1
            max_depth = 50
        elif envname == "ImprovedHopper-v0":
            min_action = -1.0
            max_action = 1.0
            dim = 3
            max_depth = 100
        elif envname == "ImprovedWalker2d-v0":
            min_action = -1.0
            max_action = 1.0
            dim = 6
            max_depth = 100
        elif envname == "ImprovedAnt-v0":
            min_action = -1.0
            max_action = 1.0
            dim = 8
            max_depth = 100
        else:
            min_action = -1.0
            max_action = 1.0
            dim = 1
            max_depth = 50

        print(f"\nEnvironment: {envname}")
        print(f"Action dimension: {dim}")
        print(f"Max depth: {max_depth}")

        # Wrap with SnapshotEnv
        planning_env = SnapshotEnv(gym.make(envname, **stoch_kwargs).env)
        root_obs_ori = planning_env.reset()
        root_snapshot_ori = planning_env.get_snapshot()

        # We'll run for each ITERATIONS in samples_to_use
        for ITERATIONS in samples_to_use:
            seed_returns = []

            for seed_i in range(num_seeds):
                random.seed(seed_i)
                np.random.seed(seed_i)

                # copy snapshot
                root_obs = copy.copy(root_obs_ori)
                root_snapshot = copy.copy(root_snapshot_ori)

                # build root node
                root = NodePW(
                    snapshot=root_snapshot,
                    obs=root_obs,
                    is_done=False,
                    parent=None,
                    depth=0,
                    min_action=min_action,
                    max_action=max_action,
                    dim=dim
                )

                # plan
                for _ in range(ITERATIONS):
                    root.selection(planning_env, max_depth)

                # test
                test_env = pickle.loads(root_snapshot)
                total_reward = 0.0
                current_discount = 1.0
                done = False

                for i in range(TEST_ITERATIONS):
                    # choose best child from root
                    if len(root.children) == 0:
                        # no children => random
                        best_action = tuple([random.uniform(min_action, max_action) for _ in range(dim)])
                    else:
                        best_child = max(root.children, key=lambda c: c.get_mean_value())
                        best_action = best_child.action

                    # step test_env
                    s, r, done, _ = test_env.step(best_action)
                    total_reward += r * current_discount
                    current_discount *= discount

                    if done:
                        test_env.close()
                        break

                    # prune other children
                    for c in list(root.children):
                        if c.action != best_action:
                            c.delete_subtree()
                            root.children.remove(c)

                    # re-root
                    chosen = None
                    for c in root.children:
                        if c.action == best_action:
                            chosen = c
                            break
                    if chosen is None:
                        # create a new child node if needed
                        chosen = NodePW(
                            snapshot=None,
                            obs=None,
                            is_done=False,
                            parent=None,
                            depth=0,
                            min_action=min_action,
                            max_action=max_action,
                            dim=dim
                        )
                    # force it as new root
                    chosen.parent = None
                    root = chosen
                    root.depth = 0

                    # re-plan
                    for _ in range(ITERATIONS):
                        root.selection(planning_env, max_depth)

                if not done:
                    test_env.close()

                seed_returns.append(total_reward)

            mean_return = statistics.mean(seed_returns)
            std_return = statistics.pstdev(seed_returns)
            interval = 2.0 * std_return

            msg = (f"Env={envname}, ITER={ITERATIONS}: "
                   f"Mean={mean_return:.3f} ± {interval:.3f} "
                   f"(over {num_seeds} seeds)")
            print(msg)
            f_out.write(msg + "\n")
            f_out.flush()

    f_out.close()
    print(f"Done! Results saved to", results_filename)
